package com.jstarcraft.ai.jsat.clustering;

import static com.jstarcraft.ai.jsat.utils.SystemInfo.LogicalCores;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.distributions.empirical.kernelfunc.GaussKF;
import com.jstarcraft.ai.jsat.distributions.empirical.kernelfunc.KernelFunction;
import com.jstarcraft.ai.jsat.distributions.multivariate.MetricKDE;
import com.jstarcraft.ai.jsat.distributions.multivariate.MultivariateKDE;
import com.jstarcraft.ai.jsat.distributions.multivariate.ProductKDE;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.VecPaired;
import com.jstarcraft.ai.jsat.linear.distancemetrics.DistanceMetric;
import com.jstarcraft.ai.jsat.linear.distancemetrics.EuclideanDistance;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

/**
 * The MeanShift algorithms performs clustering on a data set by letting the
 * data speak for itself and performing a mode search amongst the data set,
 * returning a cluster for each discovered mode. <br>
 * <br>
 * While not normally discussed in the context of Mean Shift, this
 * implementation has rudimentary outlier-removal, outliers will not be included
 * in the clustering. <br>
 * <br>
 * The mean shift requires a {@link MultivariateKDE} to run. Contrary to use in
 * density estimation, where the {@link KernelFunction} used has only a minor
 * impact on the results, it is highly recommended you use the {@link GaussKF}
 * for the MeanShift method. This is because of the large support and better
 * behaved derivative, which adds in the avoidance of oscillating convergence.
 * <br>
 * <br>
 * Implementation Note: This implementation does not snap the values to a grid.
 * This causes the prior noted oscillation in convergence.
 * 
 * @author Edward Raff
 */
public class MeanShift implements Clusterer {

    private static final long serialVersionUID = 4061491342362690455L;
    /**
     * The default number of {@link #getMaxIterations() } is
     * {@value #DefaultMaxIterations}
     */
    public static final int DefaultMaxIterations = 1000;
    /**
     * The default value of {@link #getScaleBandwidthFactor() } is
     * {@value #DefaultScaleBandwidthFactor}
     */
    public static final double DefaultScaleBandwidthFactor = 1.0;
    private MultivariateKDE mkde;
    private int maxIterations = DefaultMaxIterations;
    private double scaleBandwidthFactor = DefaultScaleBandwidthFactor;

    /**
     * Creates a new MeanShift clustering object using a {@link MetricKDE}, the
     * {@link GaussKF}, and the {@link EuclideanDistance}.
     */
    public MeanShift() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a new MeanShift clustering object using a {@link MetricKDE} and the
     * {@link GaussKF}.
     * 
     * @param dm the distance metric to use
     */
    public MeanShift(DistanceMetric dm) {
        this(new MetricKDE(GaussKF.getInstance(), dm));
    }

    /**
     * Creates a new MeanShift clustering object. <br>
     * NOTE: {@link ProductKDE} does not currently support the functions needed to
     * work with MeanShift.
     * 
     * @param mkde the KDE to use in the clustering process.
     */
    public MeanShift(MultivariateKDE mkde) {
        this.mkde = mkde;
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    public MeanShift(MeanShift toCopy) {
        this.mkde = toCopy.mkde.clone();
        this.maxIterations = toCopy.maxIterations;
        this.scaleBandwidthFactor = toCopy.scaleBandwidthFactor;
    }

    /**
     * Sets the maximum number of iterations the algorithm will go through,
     * terminating early if convergence has not occurred.
     * 
     * @param maxIterations the maximum number of iterations
     * @throws ArithmeticException if a value less than 1 is given
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0)
            throw new ArithmeticException("Invalid iteration count, " + maxIterations);
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the maximum number of iterations the algorithm will go through,
     * terminating early if convergence has not occurred.
     * 
     * @return the maximum number of iterations
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets the value by which the bandwidth of the {@link MultivariateKDE} will be
     * scaled by.
     * 
     * @param scaleBandwidthFactor the value to scale bandwidth by
     * @throws ArithmeticException if the value given is {@link Double#NaN NaN } or
     *                             {@link Double#POSITIVE_INFINITY infinity}
     */
    public void setScaleBandwidthFactor(double scaleBandwidthFactor) {
        if (Double.isNaN(scaleBandwidthFactor) || Double.isInfinite(scaleBandwidthFactor))
            throw new ArithmeticException("Invalid scale factor, " + scaleBandwidthFactor);
        this.scaleBandwidthFactor = scaleBandwidthFactor;
    }

    /**
     * Returns the value by which the bandwidth of the {@link MultivariateKDE} will
     * be scaled by.
     * 
     * @return the value to scale bandwidth by
     */
    public double getScaleBandwidthFactor() {
        return scaleBandwidthFactor;
    }

    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        try {
            if (designations == null || designations.length < dataSet.size())
                designations = new int[dataSet.size()];
            boolean[] converged = new boolean[dataSet.size()];
            Arrays.fill(converged, false);

            final KernelFunction k = mkde.getKernelFunction();
            mkde.setUsingData(dataSet, parallel);
            mkde.scaleBandwidth(scaleBandwidthFactor);

            Vec[] xit = new Vec[converged.length];
            for (int i = 0; i < xit.length; i++)
                xit[i] = dataSet.getDataPoint(i).getNumericalValues().clone();
            mainLoop(converged, xit, designations, k, parallel);

            assignmentStep(converged, xit, designations);

            return designations;
        } catch (InterruptedException ex) {
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        } catch (BrokenBarrierException ex) {
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
    }

    private void assignmentStep(boolean[] converged, Vec[] xit, int[] designations) {
        // We now repurpose the 'converged' array to indicate if the point has not yet
        // been asigned to a cluster

        // Loop through and asign clusters
        int curClusterID = 0;
        boolean progress = true;
        while (progress) {
            progress = false;
            int basePos = 0;// This will be the mode of our cluster
            while (basePos < converged.length && !converged[basePos])
                basePos++;
            for (int i = basePos; i < converged.length; i++) {
                if (!converged[i] || designations[i] == -1)
                    continue;// Already assigned
                progress = true;
                if (Math.abs(xit[basePos].pNormDist(2, xit[i])) < 1e-3) {
                    converged[i] = false;
                    designations[i] = curClusterID;
                }
            }

            curClusterID++;
        }
    }

    private void mainLoop(final boolean[] converged, final Vec[] xit, final int[] designations, final KernelFunction k, boolean parallel) throws InterruptedException, BrokenBarrierException {
        AtomicBoolean progress = new AtomicBoolean(true);

        int count = 0;
        /*
         * +1 b/c we have to wait for the worker threads, but we also want this calling
         * thread to wait with them. Hence, +1
         */
        CyclicBarrier barrier = new CyclicBarrier(LogicalCores + 1);

        final ThreadLocal<Vec> localScratch = ThreadLocal.withInitial(() -> new DenseVector(xit[0].length()));

        while (progress.get() && count++ < maxIterations) {
            progress.set(false);

            ParallelUtils.run(parallel, converged.length, (i) -> {
                if (converged[i])
                    return;
                progress.lazySet(true);
                convergenceStep(xit, i, converged, designations, localScratch.get(), k);
            });
        }

        // Fill b/c we may have bailed out due to maxIterations
        Arrays.fill(converged, true);
    }

    /**
     * Computes the meanShift of the point at the given index, and then updates the
     * vector array to indicate the movement towards the mode for the data point.
     * 
     * @param xit          the array of the current data point's positions
     * @param i            the index of the data point being considered
     * @param converged    the array used to indicate the convergence to a mode
     * @param designations the array to store value designations in
     * @param scratch      the vector to compute work with
     * @param k            the kernel function to use
     */
    private void convergenceStep(final Vec[] xit, int i, final boolean[] converged, final int[] designations, final Vec scratch, final KernelFunction k) {
        double denom = 0.0;
        Vec xCur = xit[i];
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> contrib = mkde.getNearbyRaw(xCur);

        if (contrib.size() == 1) {
            // If a point has no neighbors, it can not shift, and is its own mdoe - so we
            // mark it noise
            converged[i] = true;
            designations[i] = -1;
        } else {
            scratch.zeroOut();
            for (VecPaired<VecPaired<Vec, Integer>, Double> v : contrib) {
                double g = -k.kPrime(v.getPair());
                denom += g;
                scratch.mutableAdd(g, v);
            }
            scratch.mutableDivide(denom);

            if (Math.abs(scratch.pNormDist(2, xCur)) < 1e-5)
                converged[i] = true;

            scratch.copyTo(xCur);
        }
    }

    @Override
    public MeanShift clone() {
        return new MeanShift(this);
    }

}
